{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import progressbar\n",
    "\n",
    "from preprocessing import *\n",
    "from crfsuite_model import *\n",
    "from submit import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "level = 'char'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1. 预处理\n",
    "预处理只是简单的去除存在跨行的命名实体的多个起始位置，只保留前后两个数字存成csv格式\n",
    "\n",
    "例如糖尿\\n病 ann位置为`1 2;3 4`，在csv中改为`1 4`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1. Preprocessing ...\n"
     ]
    }
   ],
   "source": [
    "print('1. Preprocessing ...')\n",
    "clean_ann()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2. 生成单字的训练集和测试集\n",
    "这里由于是最后代码审核，就直接使用全量数据集训练了"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2. Generating train/test set from raw text ...\n"
     ]
    }
   ],
   "source": [
    "print('2. Generating train/test set from raw text ...')\n",
    "generate_char_level_train_set()\n",
    "generate_char_level_test_set()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3. 载入训练集测试集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3. Loading train data ...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100% (363 of 363) |######################| Elapsed Time: 0:01:02 Time:  0:01:02\n"
     ]
    }
   ],
   "source": [
    "print('3. Loading train data ...')\n",
    "train_sents = load_data('./data/' + level + '_level_train_set/', level)\n",
    "x_train = [sent2features(s, level) for s in progressbar.progressbar(train_sents)]\n",
    "y_train = [sent2labels(s, level) for s in train_sents]\n",
    "del train_sents  # crf模型吃内存，显示清理能减少内存占用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4. 训练CRF模型\n",
    "使用pycrfsuite训练CRF模型，时间需求约一天，模型约300M，要求16G内存（Windows系统内存需求可能更高点）。\n",
    "\n",
    "主要可调参数为C1、C2以及迭代次数，由于训练时间较长，这里就不执行了。\n",
    "\n",
    "可以在model.train()中调整模型存放目录。我们训练好了一个crf模型，请在readme中查看。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('4. Training CRF model ...')\n",
    "model = pycrfsuite.Trainer(verbose=True)\n",
    "for xseq, yseq in zip(x_train, y_train):\n",
    "    model.append(xseq, yseq)\n",
    "model.set_params({'c1': 1e-3,  # coefficient for L1 penalty\n",
    "                  'c2': 1,  # coefficient for L2 penalty\n",
    "                  'max_iterations': 500,  # stop earlier\n",
    "                  # include transitions that are possible, but not observed\n",
    "                  'feature.possible_transitions': True})\n",
    "# 'feature.minfreq': 3})\n",
    "model.train('./model/tianchi_ner_' + level + '_level.crfsuite')\n",
    "del x_train\n",
    "del y_train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step5. 生成测试结果\n",
    "这里使用我们训练好的模型展示结果，可以在tagger.open()重新指定模型目录"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5. Generating test result ...\n"
     ]
    }
   ],
   "source": [
    "print('5. Generating test result ...')\n",
    "tagger = pycrfsuite.Tagger()\n",
    "tagger.open('./model/tianchi_ner_char_level_1000.crfsuite')\n",
    "\n",
    "directory = './data/' + level + '_level_test_set/'\n",
    "filenames = os.listdir(directory)\n",
    "for filename in filenames:\n",
    "    if filename.endswith('.csv'):\n",
    "        sent = []\n",
    "        df = pd.read_csv(directory + filename, skip_blank_lines=False)\n",
    "        for index, row in df.iterrows():\n",
    "            if level == 'word':\n",
    "                sent.append((row['seq'], row['pos']))\n",
    "            elif level == 'char':\n",
    "                sent.append((row['char']))\n",
    "        x_test = sent2features(sent, level)\n",
    "        y_test = tagger.tag(x_test)\n",
    "\n",
    "        f = codecs.open('./data/submit/' + filename[:-4] + '.ann', 'w')\n",
    "        id = 0\n",
    "        start_end_pairs = [[0, 0]]  # 用列表的形式存储起始字符\n",
    "        flag = False\n",
    "        entity = ''  # 用于字符串拼接\n",
    "        for i in range(0, len(y_test)):\n",
    "            if y_test[i] != 'O':\n",
    "                if y_test[i][:1] == 'B' and flag == False:  # 若满足条件，则代表是实体的开始\n",
    "                    entity = sent[i][0]\n",
    "                    flag = True\n",
    "                    start_end_pairs[-1][0] = start_end_pairs[-1][1]\n",
    "                    del (start_end_pairs[:-1])  # B重新开始记录起止位置，所以删除前面的元素\n",
    "                elif y_test[i][:1] == 'B' and flag == True:  # 代表上一个实体刚结束，新的实体紧接着开始\n",
    "                    id += 1\n",
    "                    write_format(f, id, y_test[i - 1][2:], start_end_pairs, entity)\n",
    "                    start_end_pairs[-1][0] = start_end_pairs[-1][1]\n",
    "                    del (start_end_pairs[:-1])  # B重新开始记录起止位置，所以删除前面的元素\n",
    "                    # entity = ''\n",
    "                    entity = sent[i][0]\n",
    "                elif y_test[i][:1] == 'I':  # 代表当前token仍然属于同一实体\n",
    "                    if str(sent[i][0]) == '\\n':\n",
    "                        entity += ' '\n",
    "                        start_end_pairs.append([start_end_pairs[-1][1] + 1, start_end_pairs[-1][1]])\n",
    "                    else:\n",
    "                        entity += str(sent[i][0])\n",
    "            elif flag:\n",
    "                flag = False\n",
    "                id += 1\n",
    "                write_format(f, id, y_test[i - 1][2:], start_end_pairs, entity)\n",
    "            start_end_pairs[-1][1] += len(str(sent[i][0]))\n",
    "        f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
